{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bb22605d-8184-4ea3-895f-3ec5f845bcb7",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/cselab/odil/blob/main/examples/basic/fields.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37c6d474-fd03-440a-bb02-53dea2abc77e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "if 'google.colab' in sys.modules and not 'odil' in sys.modules:\n",
    "    %pip install -q git+https://github.com/cselab/odil\n",
    "import argparse\n",
    "import numpy as np\n",
    "import odil\n",
    "from odil import printlog\n",
    "def operator(ctx):\n",
    "    res = []\n",
    "\n",
    "    def func(x, y):\n",
    "        return x * 0.25 + y * 0.5\n",
    "\n",
    "    # Cells.\n",
    "    xc, yc = ctx.points(loc='cc')\n",
    "    uc = ctx.field('uc')\n",
    "    res += [('uc', uc - func(xc, yc))]\n",
    "\n",
    "    # Nodes.\n",
    "    xn, yn = ctx.points(loc='nn')\n",
    "    un = ctx.field('un')\n",
    "    res += [('un', un - func(xn, yn))]\n",
    "\n",
    "    # Faces in x.\n",
    "    xfx, yfx = ctx.points(loc='nc')\n",
    "    ufx = ctx.field('ufx')\n",
    "    res += [('ufx', ufx - func(xfx, yfx))]\n",
    "\n",
    "    # Faces in y.\n",
    "    xfy, yfy = ctx.points(loc='cn')\n",
    "    ufy = ctx.field('ufy')\n",
    "    res += [('ufy', ufy - func(xfy, yfy))]\n",
    "    return res\n",
    "\n",
    "\n",
    "def parse_args():\n",
    "    parser = argparse.ArgumentParser(\n",
    "        formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
    "    parser.add_argument('--Nx', type=int, default=8, help=\"Grid size in x\")\n",
    "    parser.add_argument('--Ny', type=int, default=4, help=\"Grid size in y\")\n",
    "    parser.add_argument('--plot', type=int, default=1, help=\"Plot fields\")\n",
    "    odil.util.add_arguments(parser)\n",
    "    odil.linsolver.add_arguments(parser)\n",
    "    parser.set_defaults(outdir='out_fields')\n",
    "    parser.set_defaults(multigrid=1)\n",
    "    parser.set_defaults(optimizer='adamn', lr=1e-2)\n",
    "    parser.set_defaults(frames=1,\n",
    "                        plot_every=100,\n",
    "                        report_every=10,\n",
    "                        history_every=10)\n",
    "    return parser.parse_args()\n",
    "\n",
    "\n",
    "def plot(problem, state, epoch, frame, cbinfo=None):\n",
    "    from odil import plotutil\n",
    "    import matplotlib.pyplot as plt\n",
    "    domain = problem.domain\n",
    "    fig, ax = plt.subplots(figsize=(4, 2))\n",
    "    kw = dict(vmin=0, vmax=1, cmap='Greys', clip_on=False, lw=0.5)\n",
    "\n",
    "    # Cells.\n",
    "    xc, yc = domain.points(loc='cc')\n",
    "    uc = domain.field(state, 'uc')\n",
    "    ax.scatter(xc, yc, s=10, c=uc, edgecolor='C0', label='uc', **kw)\n",
    "\n",
    "    # Nodes.\n",
    "    xn, yn = domain.points(loc='nn')\n",
    "    un = domain.field(state, 'un')\n",
    "    ax.scatter(xn, yn, s=10, c=un, edgecolor='C1', label='un', **kw)\n",
    "\n",
    "    # Faces in x.\n",
    "    xfx, yfx = domain.points(loc='nc')\n",
    "    ufx = domain.field(state, 'ufx')\n",
    "    ax.scatter(xfx, yfx, s=10, c=ufx, edgecolor='C2', label='ufx', **kw)\n",
    "\n",
    "    # Faces in y.\n",
    "    xfy, yfy = domain.points(loc='cn')\n",
    "    ufy = domain.field(state, 'ufy')\n",
    "    ax.scatter(xfy, yfy, s=10, c=ufy, edgecolor='C3', label='ufy', **kw)\n",
    "\n",
    "    ax.legend(loc='lower left',\n",
    "              bbox_to_anchor=(0.1, 1),\n",
    "              ncol=4,\n",
    "              handletextpad=0)\n",
    "\n",
    "    ax.pcolormesh(xn,\n",
    "                  yn,\n",
    "                  uc,\n",
    "                  edgecolor='k',\n",
    "                  shading='flat',\n",
    "                  zorder=0,\n",
    "                  **dict(kw, lw=0.5))\n",
    "\n",
    "    ax.set_aspect('equal')\n",
    "    ax.set_axis_off()\n",
    "    display(fig)\n",
    "\n",
    "\n",
    "def make_problem(args):\n",
    "    dtype = np.float64 if args.double else np.float32\n",
    "    domain = odil.Domain(cshape=(args.Nx, args.Ny),\n",
    "                         dimnames=['x', 'y'],\n",
    "                         lower=(0, 0),\n",
    "                         upper=(2, 1),\n",
    "                         dtype=dtype,\n",
    "                         multigrid=args.multigrid,\n",
    "                         mg_interp=args.mg_interp,\n",
    "                         mg_axes=[True, True],\n",
    "                         mg_nlvl=args.nlvl)\n",
    "    xx, yy = domain.points('x', 'y', loc='cn')\n",
    "    ixx, iyy = domain.indices('x', 'y', loc='cc')\n",
    "\n",
    "    from odil import Field\n",
    "\n",
    "    state = odil.State(\n",
    "        fields={\n",
    "            'uc': Field(np.zeros(domain.size(loc='cc')), loc='cc'),\n",
    "            'un': Field(np.zeros(domain.size(loc='nn')), loc='nn'),\n",
    "            'ufx': Field(np.zeros(domain.size(loc='nc')), loc='nc'),\n",
    "            'ufy': Field(np.zeros(domain.size(loc='cn')), loc='cn'),\n",
    "            'net': domain.make_neural_net([2, 4, 2]),\n",
    "        })\n",
    "    state = domain.init_state(state)\n",
    "    problem = odil.Problem(operator, domain)\n",
    "    return problem, state\n",
    "\n",
    "sys.argv = ['fields']\n",
    "args = parse_args()\n",
    "odil.setup_outdir(args)\n",
    "problem, state = make_problem(args)\n",
    "callback = odil.make_callback(problem,\n",
    "                                  args,\n",
    "                                  plot_func=plot if args.plot else None)\n",
    "odil.util.optimize_grad(args, args.optimizer, problem, state, callback)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
